#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "params.h"
#include "types.h"
#include "defs.h"
#include "mem.h"
#include "cpu.h"

/* Not referenced in this file; presumably a page-usage bitmap -- TODO confirm. */
uint64 mem_bit_map [256];

/* Reference copy of the bytes pgt_test writes through the MMU.  The extra
 * 4 bytes presumably give slack for a dword access at the end -- TODO confirm. */
byte shadow_memory [MEM_SIZE + 4];

/* Shift for a 4 KiB page (1 << 12 == 4096); unused in this file. */
#define PG4K_SHFT 12

typedef struct pg_node pg_node_t;

/* One physical page on either the free or the allocated list. */
struct pg_node {
    uint32     addr;    /* physical address of the page */
    pg_node_t *next;    /* next node, NULL at end of list */
};

/* Sentinel heads: only `.next` is used.  free_pg_head chains pages that
 * mem_alloc_pg can hand out; alloc_pg_head chains pages handed out. */
pg_node_t free_pg_head;
pg_node_t alloc_pg_head;

int
mem_pglist_init () {
    /*
     * Build the physical page free list: MEM_PG_NUM nodes with addresses
     * 0, PG_SIZE, 2*PG_SIZE, ...  Each node is pushed onto the head of
     * free_pg_head, so the highest address ends up first.  The allocated
     * list is reset to empty.  The printed "MAX ADDR" is the end address
     * (one past the last page).
     *
     * Returns 0 on success, -1 if a list node could not be allocated
     * (nodes created before the failure remain on the free list).
     */
    uint32 base = 0;
    int i;
    pg_node_t *node;

    free_pg_head.next = NULL;
    alloc_pg_head.next = NULL;
    for (i = 0; i < MEM_PG_NUM; i++) {
        node = malloc (sizeof *node);
        if (node == NULL) {
            fprintf (stderr, "mem_pglist_init: out of memory at page %d.\n", i);
            return -1;
        }
        node->addr = base;
        node->next = free_pg_head.next;
        free_pg_head.next = node;
        base += PG_SIZE;
    }
    printf ("MAX ADDR 0x%X.\n", base);
    return 0;
}

uint32
mem_alloc_pg () {
    /*
     * Take the first page off the free list and push it onto the
     * allocated list.  Returns its physical address, or 0xFFFFFFFF when
     * the free list is empty.
     */
    pg_node_t *node = free_pg_head.next;

    if (node != NULL) {
        free_pg_head.next = node->next;     /* unlink from free list */
        node->next = alloc_pg_head.next;    /* push onto allocated list */
        alloc_pg_head.next = node;
        return node->addr;
    }
    return 0xFFFFFFFF;
}

uint32
mem_free_pg (uint32 addr) {
    /*
     * Move the page at physical address `addr` from the allocated list to
     * the head of the free list and return 0.  If `addr` is not on the
     * allocated list (never allocated, or already freed), print a message,
     * wait for a key press and return 0xFFFFFFFF.
     */
    pg_node_t **link;
    pg_node_t *node;

    for (link = &alloc_pg_head.next; *link != NULL; link = &node->next) {
        node = *link;
        if (node->addr != addr)
            continue;
        *link = node->next;
        node->next = free_pg_head.next;
        free_pg_head.next = node;
        return 0;
    }
    printf ("0x%X not found.\n", addr);
    getchar ();
    return 0xFFFFFFFF;
}

/* Zero all 1024 32-bit entries of the table page at physical address `pg`. */
static void
pgt_clear_table (uint32 pg) {
    uint32 k;
    for (k = 0; k < 1024; k++)
        mem_phy_write (pg + k * 4, 0, RW_DWORD);
}

uint32
pgt_creat (uint32 pgs) {
    /*
     * Build a two-level page table mapping virtual pages 0 .. pgs-1
     * (VA [0, pgs * PG_SIZE)) to freshly allocated physical pages.  Every
     * level-1 and level-2 entry is written with PTE_R | PTE_W | PTE_V.
     *
     * Returns the physical address of the level-1 table, or 0xFFFFFFFF if
     * that page cannot be allocated.  If a level-2 table or a data page
     * cannot be allocated part way through, a message is printed and the
     * partially built table (covering the pages mapped so far) is returned.
     * It can still be released with pgt_destroy.
     */
    uint32 pgt1, pgt2, pg;
    uint32 i;
    uint32 va = 0;
    uint32 cnt = 0;
    uint32 pte;

    pgt1 = mem_alloc_pg ();
    if (pgt1 == 0xFFFFFFFF) {
        printf ("Unable to allocate pgt1.\n");
        return 0xFFFFFFFF;
    }
    pgt_clear_table (pgt1);

    /* va is only advanced after a page is fully mapped, so on early exit
     * it equals (pages mapped) * PG_SIZE. */
    for (i = 0; i < pgs; i++, va += PG_SIZE) {
        mem_phy_read (pgt1 + L1_PTE_INDEX(va) * 4, &pte, RW_DWORD);
        if (!(pte & PTE_V)) {
            pgt2 = mem_alloc_pg ();
            if (pgt2 == 0xFFFFFFFF) {
                /* Stop: clearing "table" 0xFFFFFFFF would wrap around and
                 * zero low physical memory. */
                printf ("Unable to allocate pgt2, current pg: %u.\n", i);
                break;
            }
            pgt_clear_table (pgt2);
            mem_phy_write (pgt1 + L1_PTE_INDEX(va) * 4, pgt2 | PTE_R | PTE_W | PTE_V, RW_DWORD);
        } else {
            pgt2 = pte & PG4K_MASK;
        }
        mem_phy_read (pgt2 + L2_PTE_INDEX(va) * 4, &pte, RW_DWORD);
        if (!(pte & PTE_V)) {
            pg = mem_alloc_pg ();
            if (pg == 0xFFFFFFFF) {
                /* Do not install the sentinel as a valid mapping. */
                printf ("Unable to allocate pg, current pg: %u.\n", i);
                getchar ();
                break;
            }
            cnt++;
            mem_phy_write (pgt2 + L2_PTE_INDEX(va) * 4, pg | PTE_R | PTE_W | PTE_V, RW_DWORD);
        }
    }
    /* Highest mapped virtual address; 0 (rather than a wrapped 0xFFFFFFFF)
     * when nothing was mapped. */
    printf ("pages created: %u, expected: %u. MAX_ADDR = 0x%X\n", cnt, pgs, va ? va - 1 : 0);
    return pgt1;
}

uint32
pgt_destroy (uint32 pgt1) {
    /*
     * Tear down a two-level page table of the shape pgt_creat builds.
     * Each valid data page and each valid level-2 table is returned with
     * mem_free_pg and its entry zeroed.  The level-1 table page is then
     * freed.  Bits of `pgt1` below the page boundary are masked off with
     * PG4K_MASK.
     *
     * Returns 0.  Returns 0xFFFFFFFF without touching memory when `pgt1` is
     * the 0xFFFFFFFF allocation-failure sentinel of mem_alloc_pg.
     */
    uint32 i, j;
    uint32 l1e, l2e, pgt2;

    if (pgt1 == 0xFFFFFFFF)
        return 0xFFFFFFFF;

    pgt1 &= PG4K_MASK;
    for (i = 0; i < 1024; i++) {
        mem_phy_read (pgt1 + i * 4, &l1e, RW_DWORD);
        if (!(l1e & PTE_V))
            continue;
        pgt2 = l1e & PG4K_MASK;
        for (j = 0; j < 1024; j++) {
            mem_phy_read (pgt2 + j * 4, &l2e, RW_DWORD);
            if (l2e & PTE_V) {
                mem_free_pg (l2e & PG4K_MASK);
                mem_phy_write (pgt2 + j * 4, 0, RW_DWORD);
            }
        }
        mem_free_pg (pgt2);
        mem_phy_write (pgt1 + i * 4, 0, RW_DWORD);
    }
    mem_free_pg (pgt1);
    return 0;
}

uint32
pgt_test () {
    uint32 pgt_addr[0x4000];
    uint32 i;
    int32 t_val;
    int32 rd;
    #define PRIME_PARAM 119
    printf ("** CHECK MMU **\n");
    printf ("Creating page table...\n");
    pgt_addr[0] = pgt_creat (0x2000 - 1025);
    printf ("Page table created.\n");
    printf ("Checking int8...\n");
    for (i = 0; i < (0x2000 - 1025) * PG_SIZE; i++) {
        rd = random ();
        bv32_mem_virt_write (i, pgt_addr[0], rd, RW_BYTE);
        *((byte*)(shadow_memory +i)) = rd;
    }
    for (i = 0; i < (0x2000 - 1025) * PG_SIZE; i++) {
        bv32_mem_virt_read (i, pgt_addr[0], &t_val, RW_BYTE);
        if ((byte)t_val != *((byte*)(shadow_memory +i))) {
            printf("E:0x%X %d:%d\n", i, (byte)t_val, *((byte*)(shadow_memory + i)));
            getchar ();
        }
    }
    printf ("int8 -- OK.\n");
    //printf ("destroying page table...\n");
    //pgt_destroy (pgt_addr[0]);
    //printf ("Page table destroyed.\n");
    printf ("Creating page table.\n");
    pgt_addr[1] = pgt_creat (0x3000 - 128);
    printf ("Page table created.\n");
    printf ("Checking int16...\n");
    for (i = 0; i < (0x3000 - 128) * PG_SIZE; i+=2) {
        rd = random ();
        bv32_mem_virt_write (i, pgt_addr[1], rd, RW_WORD);
        *((int16*)(shadow_memory + i)) = rd;
    }
    for (i = 0; i < (0x3000 - 128) * PG_SIZE; i+=2) {
        bv32_mem_virt_read (i, pgt_addr[1], &t_val, RW_WORD);
        if ((int16)t_val != *((int16*)(shadow_memory + i))) {
            printf("E:0x%X %d:%d\n", i, (int16)t_val, *((int16*)(shadow_memory + i)));
            getchar ();
        }
    }
    printf ("int16 -- OK.\n");
    printf ("Destroying page table...\n");
    pgt_destroy (pgt_addr[0]);
    printf ("Page table destroyed.\n");
    printf ("Recreating page table.\n");
    pgt_addr[2] = pgt_creat (0x7000 - 128);
    printf ("Page table created.\n");
    printf ("Checking int32...\n");
    for (i = 0; i < (0x7000 - 128) * PG_SIZE; i+=4) {
        rd = random ();
        bv32_mem_virt_write (i, pgt_addr[2], rd, RW_DWORD);
        *((int32*)(shadow_memory + i))  = rd;
    }
    for (i = 0; i < (0x7000 - 128) * PG_SIZE; i+=4) {
        bv32_mem_virt_read (i, pgt_addr[2], &t_val, RW_DWORD);
        if ((int32)t_val != *((int32*)(shadow_memory + i))) {
            printf("E:0x%X %d:%d\n", i, (int32)t_val, *((int32*)(shadow_memory + i)));
            getchar ();
        }
    }
    printf ("int32 -- OK.\n");
    printf ("MMU check finished without error.\n");
}

int
main (int argc, char **argv) {
    /*
     * Test driver: initialise the physical page free list, then run the
     * MMU self-test.  Command-line arguments are currently ignored.
     * Returns 1 if the page list could not be initialised, otherwise 0.
     */
    (void) argc;
    (void) argv;

    if (mem_pglist_init () != 0) {
        fprintf (stderr, "Failed to initialise page list.\n");
        return 1;
    }
    pgt_test ();
    return 0;
}